{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fnpzZdc01IvW"
   },
   "source": [
    "## Description\n",
    "This notebook generates the data for the shortest path problem. We generate two versions of the same data: one where graphs are vertex-weighted and one where graphs are edge-weighted. The data consists of:\n",
    "  - pairs $(d, x)$, where $d$ is a five-dimensional context vector and $x$ encodes the shortest path. In the **v**ertex version of the data set, the path is represented via a binary $m\\times m$ matrix with ones indicating the vertices traversed by the path. In the **e**dge version, the path is represented via a binary vector of length $|E|$, where $|E|$ is the number of edges, with ones indicating the edges traversed by the path.\n",
    "  - `WW` ($W$), a `context_size`-by-$m^2$ matrix which maps contexts to vertex weights.\n",
    "  - `A`, the vertex-edge adjacency matrix, and `b`, a vector (both returned by `create_shortest_path_data`).\n",
    "  - Some useful metadata, e.g. the number of edges and an edge list."
   ]
  },
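  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two encodings concrete, the next cell sketches them on a toy 3-by-3 grid. This is a minimal illustration under stated assumptions, not the code used by `create_shortest_path_data`: the edge ordering (right and down steps between neighbouring vertices) and the names `m_toy`, `edge_list_toy`, `path_toy`, `x_e_toy` and `x_v_toy` are hypothetical and may differ from the conventions in `src.utils`. The same path has four ones in the edge vector and five ones in the vertex matrix, since a path with $k$ edges visits $k+1$ vertices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy setup (an assumption, not the repository's edge list):\n",
    "# a 3-by-3 grid with edges pointing right or down between neighbouring vertices.\n",
    "m_toy = 3\n",
    "edge_list_toy = []\n",
    "for r in range(m_toy):\n",
    "    for c in range(m_toy):\n",
    "        if c + 1 < m_toy:\n",
    "            edge_list_toy.append(((r, c), (r, c + 1)))  # step right\n",
    "        if r + 1 < m_toy:\n",
    "            edge_list_toy.append(((r, c), (r + 1, c)))  # step down\n",
    "\n",
    "# A path from the top-left to the bottom-right corner, as a vertex sequence.\n",
    "path_toy = [(0, 0), (0, 1), (1, 1), (2, 1), (2, 2)]\n",
    "path_edges_toy = set(zip(path_toy[:-1], path_toy[1:]))\n",
    "\n",
    "# Edge format: binary vector of length |E| with ones on the traversed edges.\n",
    "x_e_toy = torch.tensor([1.0 if e in path_edges_toy else 0.0 for e in edge_list_toy])\n",
    "\n",
    "# Vertex format: binary m-by-m matrix with ones on the visited vertices.\n",
    "x_v_toy = torch.zeros(m_toy, m_toy)\n",
    "for (u, v), used in zip(edge_list_toy, x_e_toy.tolist()):\n",
    "    if used:\n",
    "        x_v_toy[u] = 1.0\n",
    "        x_v_toy[v] = 1.0\n",
    "\n",
    "print(len(edge_list_toy), int(x_e_toy.sum()), int(x_v_toy.sum()))  # prints: 12 4 5"
   ]
  },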
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "iGYw-_NI1Iva"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import time\n",
    "\n",
    "import sys\n",
    "root_dir = \"../../\"\n",
    "sys.path.append(root_dir)\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from torch.utils.data import DataLoader\n",
    "from src.utils import edge_to_node, create_shortest_path_data\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "ssvGZ8201Ivc"
   },
   "outputs": [],
   "source": [
    "## Fix some parameters\n",
    "train_size = 200\n",
    "test_size = 200\n",
    "context_size = 5\n",
    "# m_max = 25 # generate m-by-m graphs, with m in increments of 5, between 5 and m_max inclusive. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZBr2WNgT1Ivc",
    "outputId": "35bbc070-b068-4b1f-88fd-0fd092622765"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start building dataset for m = 30\n",
      "Created data! Now saving\n",
      "Finished building dataset for m = 30 , Time =  18.93810510635376  seconds\n"
     ]
    }
    ],
   "source": [
    "# Loop over graph sizes\n",
    "# grid_array = [5,10,15,20,25,50,100]\n",
    "# for m in grid_array:\n",
    "m = 30\n",
    "start_time = time.time()\n",
    "print('Start building dataset for m = ' + str(m))\n",
    "train_dataset_v, test_dataset_v, train_dataset_e, test_dataset_e, WW, A, b, num_edges, Edge_list = create_shortest_path_data(m, train_size, test_size, context_size)\n",
    "end_time = time.time()\n",
    "print('Created data! Now saving')\n",
    "state = {\n",
    "      'WW': WW,\n",
    "      'train_dataset_v': train_dataset_v,\n",
    "      'test_dataset_v': test_dataset_v,\n",
    "      'train_dataset_e': train_dataset_e,\n",
    "      'test_dataset_e': test_dataset_e,\n",
    "      'm': m,\n",
    "      'A':A,\n",
    "      'b':b,\n",
    "      'num_edges': num_edges,\n",
    "      'Edge_list': Edge_list,\n",
    "      }\n",
    "import os\n",
    "save_dir = './shortest_path_data/'\n",
    "os.makedirs(save_dir, exist_ok=True)  # create the output directory if it is missing\n",
    "state_path = save_dir + 'Shortest_Path_training_data' + str(m) +'.pth'\n",
    "torch.save(state, state_path)\n",
    "print('Finished building dataset for m = ' + str(m), ', Time = ', end_time - start_time, ' seconds')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_P_vred41Ive"
   },
   "source": [
    "## Examine some paths\n",
    "The next two code cells are test code that checks whether the new data-generating functions work correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "13vOhqn21Ive"
   },
   "outputs": [],
   "source": [
    "## Examine some paths\n",
    "state = torch.load('./shortest_path_data/Shortest_Path_training_data30.pth')\n",
    "\n",
    "## Extract data from state\n",
    "train_dataset_e = state['train_dataset_e']\n",
    "test_dataset_e = state['test_dataset_e']\n",
    "train_dataset_v = state['train_dataset_v']\n",
    "test_dataset_v = state['test_dataset_v']\n",
    "m = state[\"m\"]\n",
    "A = state[\"A\"].float()\n",
    "b = state[\"b\"].float()\n",
    "num_edges = state[\"num_edges\"]\n",
    "Edge_list = state[\"Edge_list\"]\n",
    "Edge_list_torch = torch.tensor(Edge_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HUjJnKwh1Ivj"
   },
   "source": [
    "## Visualization\n",
    "The next block of code visually verifies that the data in **v**ertex format and **e**dge format are the same. For any choice of `i`, the two paths illustrated should be the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "boPxUzcp1Ivk",
    "outputId": "eaf44886-ca1d-4201-f21f-d6f17fec9a23"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZMElEQVR4nO3df0xV9/3H8df1B1dt4VJEuNyJDm3VrSrNXGXE1tlJBJYYrf6hbf/Qxmh02ExZ14am1botYbGJM22Y/rPJllTtTKqm5jsbxYLpBi5SjTHb+AphU8MPVzPuRaxI5fP9o9/e7VatXryXN/fyfCQn8d574L5PTuOzx3v44HHOOQEAMMhGWA8AABieCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAxynqAr+rv71dbW5tSU1Pl8XisxwEARMk5p+7ubgUCAY0YcffrnCEXoLa2NuXm5lqPAQB4QJcuXdLEiRPv+vqQC1Bqaqok6Z+ffFNpD9/fvxA+O21WPEcCAEThc/XpY/1P+O/zu4lbgKqqqvTWW2+po6ND+fn5eueddzR37tx7ft2X/+yW9vAIpaXeX4BGeUY/0KwAgBj6/xVG7/UxSlxuQnjvvfdUXl6urVu36pNPPlF+fr6Ki4t15cqVeLwdACABxSVAO3bs0Nq1a/Xiiy/q29/+tnbv3q1x48bpt7/9bTzeDgCQgGIeoJs3b6qxsVFFRUX/eZMRI1RUVKT6+vrb9u/t7VUoFIrYAADJL+YB+vTTT3Xr1i1lZ2dHPJ+dna2Ojo7b9q+srJTP5wtv3AEHAMOD+Q+iVlRUKBgMhrdLly5ZjwQAGAQxvwsuMzNTI0eOVGdnZ8TznZ2d8vv9t+3v9Xrl9XpjPQYAYIiL+RVQSkqK5syZo5qamvBz/f39qqmpUWFhYazfDgCQoOLyc0Dl5eVatWqVvvvd72ru3LnauXOnenp69OKLL8bj7QAACSguAVqxYoX+9a9/acuWLero6NATTzyho0eP3nZjAgBg+PI455z1EP8tFArJ5/Pp3/875b5XQigOPBHfoQAA9+1z16daHVYwGFRaWtpd9zO/Cw4AMDwRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAiZgH6M0335TH44nYZsyYEeu3AQAkuFHx+KaPP/64jh8//p83GRWXtwEAJLC4lGHUqFHy+/3x+NYAgCQRl8+ALly4oEAgoClTpuiFF17QxYsX77pvb2+vQqFQxAYASH4xD1BBQYGqq6t19OhR7dq1S62trXr66afV3d19x/0rKyvl8/nCW25ubqxHAgAMQR7nnIvnG3R1dWny5MnasWOH1qxZc9vrvb296u3tDT8OhULKzc3Vv/93itJS76+PxYEnYjUuAOABfe76VKvDCgaDSktLu+t+cb87ID09XdOmTVNzc/MdX/d6vfJ6vfEeAwAwxMT954CuXbumlpYW5eTkxPutAAAJJOYBevnll1VXV6d//OMf+vOf/6xnn31WI0eO1HPPPRfrtwIAJLCY/xPc5cuX9dxzz+nq1auaMGGCnnrqKTU0NGjChAmxfisAQAKLeYD2798f628JAEhCrAUHADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADARNx/H9Bg+LDtrPUICYFf3AdgKOEKCABgggABAEwQIACACQIEADBBgAAA
JggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwMWQXI3122iyN8oy2HmNIi3YR1oEs2soCpgDihSsgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJobsWnC4t2jXaRvIWnAD+ZqhhvXsgKGJKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmWAtuGEmWNdGiXZ8uGdazw/1Jlv/GhwuugAAAJggQAMBE1AE6efKkFi9erEAgII/Ho0OHDkW87pzTli1blJOTo7Fjx6qoqEgXLlyI1bwAgCQRdYB6enqUn5+vqqqqO76+fft2vf3229q9e7dOnTqlhx56SMXFxbpx48YDDwsASB5R34RQWlqq0tLSO77mnNPOnTv1+uuva8mSJZKk3//+98rOztahQ4e0cuXKB5sWAJA0YvoZUGtrqzo6OlRUVBR+zufzqaCgQPX19Xf8mt7eXoVCoYgNAJD8Yhqgjo4OSVJ2dnbE89nZ2eHXvqqyslI+ny+85ebmxnIkAMAQZX4XXEVFhYLBYHi7dOmS9UgAgEEQ0wD5/X5JUmdnZ8TznZ2d4de+yuv1Ki0tLWIDACS/mAYoLy9Pfr9fNTU14edCoZBOnTqlwsLCWL4VACDBRX0X3LVr19Tc3Bx+3NraqrNnzyojI0OTJk3Spk2b9Itf/EKPPfaY8vLy9MYbbygQCGjp0qWxnBsAkOCiDtDp06f1zDPPhB+Xl5dLklatWqXq6mq98sor6unp0bp169TV1aWnnnpKR48e1ZgxY2I3NQAg4Xmcc856iP8WCoXk8/m0QEs0yjPaehwAhqJdSJbFSIeGz12fanVYwWDwaz/XN78LDgAwPBEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAR9WKkADBURbt2nMT6cZa4AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCteAADFnRrtM2kLXgYIcrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADARdYBOnjypxYsXKxAIyOPx6NChQxGvr169Wh6PJ2IrKSmJ1bwAgCQRdYB6enqUn5+vqqqqu+5TUlKi9vb28LZv374HGhIAkHxGRfsFpaWlKi0t/dp9vF6v/H7/gIcCACS/uHwGVFtbq6ysLE2fPl0bNmzQ1atX4/E2AIAEFvUV0L2UlJRo2bJlysvLU0tLi1577TWVlpaqvr5eI0eOvG3/3t5e9fb2hh+HQqFYjwQAGIJiHqCVK1eG/zxr1izNnj1bU6dOVW1trRYuXHjb/pWVldq2bVusxwAADHFxvw17ypQpyszMVHNz8x1fr6ioUDAYDG+XLl2K90gAgCEg5ldAX3X58mVdvXpVOTk5d3zd6/XK6/XGewwAwBATdYCuXbsWcTXT2tqqs2fPKiMjQxkZGdq2bZuWL18uv9+vlpYWvfLKK3r00UdVXFwc08EBAIkt6gCdPn1azzzzTPhxeXm5JGnVqlXatWuXzp07p9/97nfq6upSIBDQokWL9POf/5yrHABAhKgDtGDBAjnn7vr6hx9++EADAQCGB9aCAwCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT
BAgAYIIAAQBMjLIeAAAsfdh2Nqr9iwNPxGWO4YgrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABIuRAkgaA1koNNrFSBE7XAEBAExEFaDKyko9+eSTSk1NVVZWlpYuXaqmpqaIfW7cuKGysjKNHz9eDz/8sJYvX67Ozs6YDg0ASHxRBaiurk5lZWVqaGjQsWPH1NfXp0WLFqmnpye8z+bNm/XBBx/owIEDqqurU1tbm5YtWxbzwQEAiS2qz4COHj0a8bi6ulpZWVlqbGzU/PnzFQwG9Zvf/EZ79+7VD37wA0nSnj179K1vfUsNDQ363ve+F7vJAQAJ7YE+AwoGg5KkjIwMSVJjY6P6+vpUVFQU3mfGjBmaNGmS6uvr7/g9ent7FQqFIjYAQPIbcID6+/u1adMmzZs3TzNnzpQkdXR0KCUlRenp6RH7Zmdnq6Oj447fp7KyUj6fL7zl5uYOdCQAQAIZcIDKysp0/vx57d+//4EGqKioUDAYDG+XLl16oO8HAEgMA/o5oI0bN+rIkSM6efKkJk6cGH7e7/fr5s2b6urqirgK6uzslN/vv+P38nq98nq9AxkDAJDAoroCcs5p48aNOnjwoE6cOKG8vLyI1+fMmaPRo0erpqYm/FxTU5MuXryowsLC2EwMAEgKUV0BlZWVae/evTp8+LBSU1PDn+v4fD6NHTtWPp9Pa9asUXl5uTIyMpSWlqaXXnpJhYWF3AEHAIgQVYB27dolSVqwYEHE83v27NHq1aslSb/61a80YsQILV++XL29vSouLtavf/3rmAwLAEgeUQXIOXfPfcaMGaOqqipVVVUNeCgAQPJjLTgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMjLIeAAASyYdtZ6PavzjwRFzmSAZcAQEATBAgAICJqAJUWVmpJ598UqmpqcrKytLSpUvV1NQUsc+CBQvk8XgitvXr18d0aABA4osqQHV1dSorK1NDQ4OOHTumvr4+LVq0SD09PRH7rV27Vu3t7eFt+/btMR0aAJD4oroJ4ejRoxGPq6urlZWVpcbGRs2fPz/8/Lhx4+T3+2MzIQAgKT3QZ0DBYFCSlJGREfH8u+++q8zMTM2cOVMVFRW6fv36Xb9Hb2+vQqFQxAYASH4Dvg27v79fmzZt0rx58zRz5szw888//7wmT56sQCCgc+fO6dVXX1VTU5Pef//9O36fyspKbdu2baBjAAASlMc55wbyhRs2bNAf//hHffzxx5o4ceJd9ztx4oQWLlyo5uZmTZ069bbXe3t71dvbG34cCoWUm5urBVqiUZ7RAxkNAO5btD/XE63h+HNAn7s+1eqwgsGg0tLS7rrfgK6ANm7cqCNHjujkyZNfGx9JKigokKS7Bsjr9crr9Q5kDABAAosqQM45vfTSSzp48KBqa2uVl5d3z685e/asJCknJ2dAAwIAklNUASorK9PevXt1+PBhpaamqqOjQ5Lk8/k0duxYtbS0aO/evfrhD3+o8ePH69y5c9q8ebPmz5+v2bNnx+UAAACJKaoA7dq1S9IXP2z63/bs2aPVq1crJSVFx48f186dO9XT06Pc3FwtX75cr7/+eswGBgAkhwHfhBAvoVBIPp+PmxAADEnxvmlBSvwbF+73JgTWggMAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQI
AGBiwL8RFQCGo2jXaRvI2nHD5ZfkcQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABGvBAUAcDca6a9GuHRfvteZC3f16ZNq99+MKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwWKkAJDghuKCp/eDKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMDHkluJxzkmSPlef5IyHAQBIkkLd/fe/77Uv9v3y7/O7GXIB6u7uliR9rP8xngQA8KVHpkX/Nd3d3fL5fHd93ePulahB1t/fr7a2NqWmpsrj8US8FgqFlJubq0uXLiktLc1owsE1HI9ZGp7HPRyPWeK4k/G4nXPq7u5WIBDQiBF3/6RnyF0BjRgxQhMnTvzafdLS0pLuhN3LcDxmaXge93A8ZonjTjZfd+XzJW5CAACYIEAAABMJFSCv16utW7fK6/VajzJohuMxS8PzuIfjMUsc93A77v825G5CAAAMDwl1BQQASB4ECABgggABAEwQIACAiYQJUFVVlb75zW9qzJgxKigo0F/+8hfrkeLqzTfflMfjidhmzJhhPVZMnTx5UosXL1YgEJDH49GhQ4ciXnfOacuWLcrJydHYsWNVVFSkCxcu2AwbQ/c67tWrV9927ktKSmyGjZHKyko9+eSTSk1NVVZWlpYuXaqmpqaIfW7cuKGysjKNHz9eDz/8sJYvX67Ozk6jiWPjfo57wYIFt53v9evXG008uBIiQO+9957Ky8u1detWffLJJ8rPz1dxcbGuXLliPVpcPf7442pvbw9vH3/8sfVIMdXT06P8/HxVVVXd8fXt27fr7bff1u7du3Xq1Ck99NBDKi4u1o0bNwZ50ti613FLUklJScS537dv3yBOGHt1dXUqKytTQ0ODjh07pr6+Pi1atEg9PT3hfTZv3qwPPvhABw4cUF1dndra2rRs2TLDqR/c/Ry3JK1duzbifG/fvt1o4kHmEsDcuXNdWVlZ+PGtW7dcIBBwlZWVhlPF19atW11+fr71GINGkjt48GD4cX9/v/P7/e6tt94KP9fV1eW8Xq/bt2+fwYTx8dXjds65VatWuSVLlpjMM1iuXLniJLm6ujrn3BfndvTo0e7AgQPhff72t785Sa6+vt5qzJj76nE759z3v/999+Mf/9huKEND/gro5s2bamxsVFFRUfi5ESNGqKioSPX19YaTxd+FCxcUCAQ0ZcoUvfDCC7p48aL1SIOmtbVVHR0dEefd5/OpoKAg6c+7JNXW1iorK0vTp0/Xhg0bdPXqVeuRYioYDEqSMjIyJEmNjY3q6+uLON8zZszQpEmTkup8f/W4v/Tuu+8qMzNTM2fOVEVFha5fv24x3qAbcouRftWnn36qW7duKTs7O+L57Oxs/f3vfzeaKv4KCgpUXV2t6dOnq729Xdu2bdPTTz+t8+fPKzU11Xq8uOvo6JCkO573L19LViUlJVq2bJny8vLU0tKi1157TaWlpaqvr9fIkSOtx3tg/f392rRpk+bNm6eZM2dK+uJ8p6SkKD09PWLfZDrfdzpuSXr++ec1efJkBQIBnTt3Tq+++qqampr0/vvvG047OIZ8gIar0tLS8J9nz56tgoICTZ48WX/4wx+0Zs0aw8kQbytXrgz/edasWZo9e7amTp2q2tpaLVy40HCy2CgrK9P58+eT7jPNe7nbca9bty7851mzZiknJ0cLFy5US0uLpk6dOthjDqoh/09wmZmZGjly5G13w3R2dsrv9xtNNfjS09M1bdo0NTc3W48yKL48t8P9vEvSlClTlJmZmRTnfuPGjTpy5Ig++uijiF+74vf7dfPmTXV1dUXsnyzn+27HfScFBQWSlBTn+16GfIBSUlI0Z84c1dTUhJ/r7+9XTU2NCgsLDScbXNeuXVNLS4tycnKsRxkU
eXl58vv9Eec9FArp1KlTw+q8S9Lly5d19erVhD73zjlt3LhRBw8e1IkTJ5SXlxfx+pw5czR69OiI893U1KSLFy8m9Pm+13HfydmzZyUpoc/3fbO+C+J+7N+/33m9XlddXe3++te/unXr1rn09HTX0dFhPVrc/OQnP3G1tbWutbXV/elPf3JFRUUuMzPTXblyxXq0mOnu7nZnzpxxZ86ccZLcjh073JkzZ9w///lP55xzv/zlL116ero7fPiwO3funFuyZInLy8tzn332mfHkD+brjru7u9u9/PLLrr6+3rW2trrjx4+773znO+6xxx5zN27csB59wDZs2OB8Pp+rra117e3t4e369evhfdavX+8mTZrkTpw44U6fPu0KCwtdYWGh4dQP7l7H3dzc7H72s5+506dPu9bWVnf48GE3ZcoUN3/+fOPJB0dCBMg559555x03adIkl5KS4ubOnesaGhqsR4qrFStWuJycHJeSkuK+8Y1vuBUrVrjm5mbrsWLqo48+cpJu21atWuWc++JW7DfeeMNlZ2c7r9frFi5c6JqammyHjoGvO+7r16+7RYsWuQkTJrjRo0e7yZMnu7Vr1yb8/2zd6XgluT179oT3+eyzz9yPfvQj98gjj7hx48a5Z5991rW3t9sNHQP3Ou6LFy+6+fPnu4yMDOf1et2jjz7qfvrTn7pgMGg7+CDh1zEAAEwM+c+AAADJiQABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAw8X/jmiEAdp2JOgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZMElEQVR4nO3df0xV9/3H8df1B1dt4VJEuNyJDm3VrSrNXGXE1tlJBJYYrf6hbf/Qxmh02ExZ14am1botYbGJM22Y/rPJllTtTKqm5jsbxYLpBi5SjTHb+AphU8MPVzPuRaxI5fP9o9/e7VatXryXN/fyfCQn8d574L5PTuOzx3v44HHOOQEAMMhGWA8AABieCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAxynqAr+rv71dbW5tSU1Pl8XisxwEARMk5p+7ubgUCAY0YcffrnCEXoLa2NuXm5lqPAQB4QJcuXdLEiRPv+vqQC1Bqaqok6Z+ffFNpD9/fvxA+O21WPEcCAEThc/XpY/1P+O/zu4lbgKqqqvTWW2+po6ND+fn5eueddzR37tx7ft2X/+yW9vAIpaXeX4BGeUY/0KwAgBj6/xVG7/UxSlxuQnjvvfdUXl6urVu36pNPPlF+fr6Ki4t15cqVeLwdACABxSVAO3bs0Nq1a/Xiiy/q29/+tnbv3q1x48bpt7/9bTzeDgCQgGIeoJs3b6qxsVFFRUX/eZMRI1RUVKT6+vrb9u/t7VUoFIrYAADJL+YB+vTTT3Xr1i1lZ2dHPJ+dna2Ojo7b9q+srJTP5wtv3AEHAMOD+Q+iVlRUKBgMhrdLly5ZjwQAGAQxvwsuMzNTI0eOVGdnZ8TznZ2d8vv9t+3v9Xrl9XpjPQYAYIiL+RVQSkqK5syZo5qamvBz/f39qqmpUWFhYazfDgCQoOLyc0Dl5eVatWqVvvvd72ru3LnauXOnenp69OKLL8bj7QAACSguAVqxYoX+9a9/acuWLero6NATTzyho0eP3nZjAgBg+PI455z1EP8tFArJ5/Pp3/875b5XQigOPBHfoQAA9+1z16daHVYwGFRaWtpd9zO/Cw4AMDwRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAiZgH6M0335TH44nYZsyYEeu3AQAkuFHx+KaPP/64jh8//p83GRWXtwEAJLC4lGHUqFHy+/3x+NYAgCQRl8+ALly4oEAgoClTpuiFF17QxYsX77pvb2+vQqFQxAYASH4xD1BBQYGqq6t19OhR7dq1S62trXr66afV3d19x/0rKyvl8/nCW25ubqxHAgAMQR7nnIvnG3R1dWny5MnasWOH1qxZc9vrvb296u3tDT8OhULKzc3Vv/93itJS76+PxYEnYjUuAOABfe76VKvDCgaDSktLu+t+cb87ID09XdOmTVNzc/MdX/d6vfJ6vfEeAwAwxMT954CuXbumlpYW5eTkxPutAAAJJOYBevnll1VXV6d//OMf+vOf/6xnn31WI0eO1HPPPRfrtwIAJLCY/xPc5cuX9dxzz+nq1auaMGGCnnrqKTU0NGjChAmxfisAQAKLeYD2798f628JAEhCrAUHADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADARNx/H9Bg+LDtrPUICYFf3AdgKOEKCABgggABAEwQIACACQIEADBBgAAA
JggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwMWQXI3122iyN8oy2HmNIi3YR1oEs2soCpgDihSsgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJobsWnC4t2jXaRvIWnAD+ZqhhvXsgKGJKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmWAtuGEmWNdGiXZ8uGdazw/1Jlv/GhwuugAAAJggQAMBE1AE6efKkFi9erEAgII/Ho0OHDkW87pzTli1blJOTo7Fjx6qoqEgXLlyI1bwAgCQRdYB6enqUn5+vqqqqO76+fft2vf3229q9e7dOnTqlhx56SMXFxbpx48YDDwsASB5R34RQWlqq0tLSO77mnNPOnTv1+uuva8mSJZKk3//+98rOztahQ4e0cuXKB5sWAJA0YvoZUGtrqzo6OlRUVBR+zufzqaCgQPX19Xf8mt7eXoVCoYgNAJD8Yhqgjo4OSVJ2dnbE89nZ2eHXvqqyslI+ny+85ebmxnIkAMAQZX4XXEVFhYLBYHi7dOmS9UgAgEEQ0wD5/X5JUmdnZ8TznZ2d4de+yuv1Ki0tLWIDACS/mAYoLy9Pfr9fNTU14edCoZBOnTqlwsLCWL4VACDBRX0X3LVr19Tc3Bx+3NraqrNnzyojI0OTJk3Spk2b9Itf/EKPPfaY8vLy9MYbbygQCGjp0qWxnBsAkOCiDtDp06f1zDPPhB+Xl5dLklatWqXq6mq98sor6unp0bp169TV1aWnnnpKR48e1ZgxY2I3NQAg4Xmcc856iP8WCoXk8/m0QEs0yjPaehwAhqJdSJbFSIeGz12fanVYwWDwaz/XN78LDgAwPBEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAR9WKkADBURbt2nMT6cZa4AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCteAADFnRrtM2kLXgYIcrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADARdYBOnjypxYsXKxAIyOPx6NChQxGvr169Wh6PJ2IrKSmJ1bwAgCQRdYB6enqUn5+vqqqqu+5TUlKi9vb28LZv374HGhIAkHxGRfsFpaWlKi0t/dp9vF6v/H7/gIcCACS/uHwGVFtbq6ysLE2fPl0bNmzQ1atX4/E2AIAEFvUV0L2UlJRo2bJlysvLU0tLi1577TWVlpaqvr5eI0eOvG3/3t5e9fb2hh+HQqFYjwQAGIJiHqCVK1eG/zxr1izNnj1bU6dOVW1trRYuXHjb/pWVldq2bVusxwAADHFxvw17ypQpyszMVHNz8x1fr6ioUDAYDG+XLl2K90gAgCEg5ldAX3X58mVdvXpVOTk5d3zd6/XK6/XGewwAwBATdYCuXbsWcTXT2tqqs2fPKiMjQxkZGdq2bZuWL18uv9+vlpYWvfLKK3r00UdVXFwc08EBAIkt6gCdPn1azzzzTPhxeXm5JGnVqlXatWuXzp07p9/97nfq6upSIBDQokWL9POf/5yrHABAhKgDtGDBAjnn7vr6hx9++EADAQCGB9aCAwCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT
BAgAYIIAAQBMjLIeAAAsfdh2Nqr9iwNPxGWO4YgrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABIuRAkgaA1koNNrFSBE7XAEBAExEFaDKyko9+eSTSk1NVVZWlpYuXaqmpqaIfW7cuKGysjKNHz9eDz/8sJYvX67Ozs6YDg0ASHxRBaiurk5lZWVqaGjQsWPH1NfXp0WLFqmnpye8z+bNm/XBBx/owIEDqqurU1tbm5YtWxbzwQEAiS2qz4COHj0a8bi6ulpZWVlqbGzU/PnzFQwG9Zvf/EZ79+7VD37wA0nSnj179K1vfUsNDQ363ve+F7vJAQAJ7YE+AwoGg5KkjIwMSVJjY6P6+vpUVFQU3mfGjBmaNGmS6uvr7/g9ent7FQqFIjYAQPIbcID6+/u1adMmzZs3TzNnzpQkdXR0KCUlRenp6RH7Zmdnq6Oj447fp7KyUj6fL7zl5uYOdCQAQAIZcIDKysp0/vx57d+//4EGqKioUDAYDG+XLl16oO8HAEgMA/o5oI0bN+rIkSM6efKkJk6cGH7e7/fr5s2b6urqirgK6uzslN/vv+P38nq98nq9AxkDAJDAoroCcs5p48aNOnjwoE6cOKG8vLyI1+fMmaPRo0erpqYm/FxTU5MuXryowsLC2EwMAEgKUV0BlZWVae/evTp8+LBSU1PDn+v4fD6NHTtWPp9Pa9asUXl5uTIyMpSWlqaXXnpJhYWF3AEHAIgQVYB27dolSVqwYEHE83v27NHq1aslSb/61a80YsQILV++XL29vSouLtavf/3rmAwLAEgeUQXIOXfPfcaMGaOqqipVVVUNeCgAQPJjLTgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMjLIeAAASyYdtZ6PavzjwRFzmSAZcAQEATBAgAICJqAJUWVmpJ598UqmpqcrKytLSpUvV1NQUsc+CBQvk8XgitvXr18d0aABA4osqQHV1dSorK1NDQ4OOHTumvr4+LVq0SD09PRH7rV27Vu3t7eFt+/btMR0aAJD4oroJ4ejRoxGPq6urlZWVpcbGRs2fPz/8/Lhx4+T3+2MzIQAgKT3QZ0DBYFCSlJGREfH8u+++q8zMTM2cOVMVFRW6fv36Xb9Hb2+vQqFQxAYASH4Dvg27v79fmzZt0rx58zRz5szw888//7wmT56sQCCgc+fO6dVXX1VTU5Pef//9O36fyspKbdu2baBjAAASlMc55wbyhRs2bNAf//hHffzxx5o4ceJd9ztx4oQWLlyo5uZmTZ069bbXe3t71dvbG34cCoWUm5urBVqiUZ7RAxkNAO5btD/XE63h+HNAn7s+1eqwgsGg0tLS7rrfgK6ANm7cqCNHjujkyZNfGx9JKigokKS7Bsjr9crr9Q5kDABAAosqQM45vfTSSzp48KBqa2uVl5d3z685e/asJCknJ2dAAwIAklNUASorK9PevXt1+PBhpaamqqOjQ5Lk8/k0duxYtbS0aO/evfrhD3+o8ePH69y5c9q8ebPmz5+v2bNnx+UAAACJKaoA7dq1S9IXP2z63/bs2aPVq1crJSVFx48f186dO9XT06Pc3FwtX75cr7/+eswGBgAkhwHfhBAvoVBIPp+PmxAADEnxvmlBSvwbF+73JgTWggMAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQI
AGBiwL8RFQCGo2jXaRvI2nHD5ZfkcQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABGvBAUAcDca6a9GuHRfvteZC3f16ZNq99+MKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwWKkAJDghuKCp/eDKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMDHkluJxzkmSPlef5IyHAQBIkkLd/fe/77Uv9v3y7/O7GXIB6u7uliR9rP8xngQA8KVHpkX/Nd3d3fL5fHd93ePulahB1t/fr7a2NqWmpsrj8US8FgqFlJubq0uXLiktLc1owsE1HI9ZGp7HPRyPWeK4k/G4nXPq7u5WIBDQiBF3/6RnyF0BjRgxQhMnTvzafdLS0pLuhN3LcDxmaXge93A8ZonjTjZfd+XzJW5CAACYIEAAABMJFSCv16utW7fK6/VajzJohuMxS8PzuIfjMUsc93A77v825G5CAAAMDwl1BQQASB4ECABgggABAEwQIACAiYQJUFVVlb75zW9qzJgxKigo0F/+8hfrkeLqzTfflMfjidhmzJhhPVZMnTx5UosXL1YgEJDH49GhQ4ciXnfOacuWLcrJydHYsWNVVFSkCxcu2AwbQ/c67tWrV9927ktKSmyGjZHKyko9+eSTSk1NVVZWlpYuXaqmpqaIfW7cuKGysjKNHz9eDz/8sJYvX67Ozk6jiWPjfo57wYIFt53v9evXG008uBIiQO+9957Ky8u1detWffLJJ8rPz1dxcbGuXLliPVpcPf7442pvbw9vH3/8sfVIMdXT06P8/HxVVVXd8fXt27fr7bff1u7du3Xq1Ck99NBDKi4u1o0bNwZ50ti613FLUklJScS537dv3yBOGHt1dXUqKytTQ0ODjh07pr6+Pi1atEg9PT3hfTZv3qwPPvhABw4cUF1dndra2rRs2TLDqR/c/Ry3JK1duzbifG/fvt1o4kHmEsDcuXNdWVlZ+PGtW7dcIBBwlZWVhlPF19atW11+fr71GINGkjt48GD4cX9/v/P7/e6tt94KP9fV1eW8Xq/bt2+fwYTx8dXjds65VatWuSVLlpjMM1iuXLniJLm6ujrn3BfndvTo0e7AgQPhff72t785Sa6+vt5qzJj76nE759z3v/999+Mf/9huKEND/gro5s2bamxsVFFRUfi5ESNGqKioSPX19YaTxd+FCxcUCAQ0ZcoUvfDCC7p48aL1SIOmtbVVHR0dEefd5/OpoKAg6c+7JNXW1iorK0vTp0/Xhg0bdPXqVeuRYioYDEqSMjIyJEmNjY3q6+uLON8zZszQpEmTkup8f/W4v/Tuu+8qMzNTM2fOVEVFha5fv24x3qAbcouRftWnn36qW7duKTs7O+L57Oxs/f3vfzeaKv4KCgpUXV2t6dOnq729Xdu2bdPTTz+t8+fPKzU11Xq8uOvo6JCkO573L19LViUlJVq2bJny8vLU0tKi1157TaWlpaqvr9fIkSOtx3tg/f392rRpk+bNm6eZM2dK+uJ8p6SkKD09PWLfZDrfdzpuSXr++ec1efJkBQIBnTt3Tq+++qqampr0/vvvG047OIZ8gIar0tLS8J9nz56tgoICTZ48WX/4wx+0Zs0aw8kQbytXrgz/edasWZo9e7amTp2q2tpaLVy40HCy2CgrK9P58+eT7jPNe7nbca9bty7851mzZiknJ0cLFy5US0uLpk6dOthjDqoh/09wmZmZGjly5G13w3R2dsrv9xtNNfjS09M1bdo0NTc3W48yKL48t8P9vEvSlClTlJmZmRTnfuPGjTpy5Ig++uijiF+74vf7dfPmTXV1dUXsnyzn+27HfScFBQWSlBTn+16GfIBSUlI0Z84c1dTUhJ/r7+9XTU2NCgsLDScbXNeuXVNLS4tycnKsRxkU
eXl58vv9Eec9FArp1KlTw+q8S9Lly5d19erVhD73zjlt3LhRBw8e1IkTJ5SXlxfx+pw5czR69OiI893U1KSLFy8m9Pm+13HfydmzZyUpoc/3fbO+C+J+7N+/33m9XlddXe3++te/unXr1rn09HTX0dFhPVrc/OQnP3G1tbWutbXV/elPf3JFRUUuMzPTXblyxXq0mOnu7nZnzpxxZ86ccZLcjh073JkzZ9w///lP55xzv/zlL116ero7fPiwO3funFuyZInLy8tzn332mfHkD+brjru7u9u9/PLLrr6+3rW2trrjx4+773znO+6xxx5zN27csB59wDZs2OB8Pp+rra117e3t4e369evhfdavX+8mTZrkTpw44U6fPu0KCwtdYWGh4dQP7l7H3dzc7H72s5+506dPu9bWVnf48GE3ZcoUN3/+fOPJB0dCBMg559555x03adIkl5KS4ubOnesaGhqsR4qrFStWuJycHJeSkuK+8Y1vuBUrVrjm5mbrsWLqo48+cpJu21atWuWc++JW7DfeeMNlZ2c7r9frFi5c6JqammyHjoGvO+7r16+7RYsWuQkTJrjRo0e7yZMnu7Vr1yb8/2zd6XgluT179oT3+eyzz9yPfvQj98gjj7hx48a5Z5991rW3t9sNHQP3Ou6LFy+6+fPnu4yMDOf1et2jjz7qfvrTn7pgMGg7+CDh1zEAAEwM+c+AAADJiQABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAw8X/jmiEAdp2JOgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
    ],
   "source": [
    "train_loader_e = DataLoader(dataset=train_dataset_e, batch_size=200,\n",
    "                            shuffle=False)\n",
    "d_batch_e, path_batch_e = next(iter(train_loader_e))\n",
    "\n",
    "train_loader_v = DataLoader(dataset=train_dataset_v, batch_size=200,\n",
    "                            shuffle=False)\n",
    "d_batch_v, path_batch_v = next(iter(train_loader_v))\n",
    "\n",
    "# Choose a sample\n",
    "i = 47\n",
    "\n",
    "plt.imshow(path_batch_v[i,:,:])\n",
    "plt.show()\n",
    "plt.imshow(edge_to_node(path_batch_e[i,:], Edge_list, m, 'cpu'))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
